{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Few-Shot Learning With Prototypical Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['omniglot']\n"
     ]
    }
   ],
   "source": [
    "# This Python 3 environment comes with many helpful analytics libraries installed\n",
    "# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python\n",
    "# For example, here's several helpful packages to load in \n",
    "\n",
    "import numpy as np # linear algebra\n",
    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
    "from matplotlib import pyplot as plt\n",
    "import cv2\n",
    "from tqdm import tqdm\n",
    "import multiprocessing as mp\n",
    "tqdm.pandas(desc=\"my bar!\")\n",
    "# Input data files are available in the \"../input/\" directory.\n",
    "# For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory\n",
    "\n",
    "import os\n",
    "print(os.listdir(\"../input\"))\n",
    "\n",
    "# Any results you write to the current directory are saved as output."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Reading and Augmentation\n",
    "The Omniglot data set is designed for developing more human-like learning algorithms. It contains 1623 different handwritten characters from 50 different alphabets. Then to increase the number of classes, all the images are rotated by 90, 180 and 270 degrees and each rotation resulted in one more class. Hence the total count of classes reached to 6492(1623 * 4) classes. We split images of 4200 classes to training data and the rest went to test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [],
   "source": [
    "train_dir = os.listdir('../input/omniglot/images_background/')\n",
    "datax = np.array([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_rotate(img, angle):\n",
    "    \"\"\"\n",
    "    Image rotation at certain angle. It is used for data augmentation \n",
    "    \"\"\"\n",
    "    rows,cols, _ = img.shape\n",
    "    M = cv2.getRotationMatrix2D((cols/2 ,rows/2),angle,1)\n",
    "    dst = cv2.warpAffine(img,M,(cols,rows))\n",
    "    return np.expand_dims(dst, 0)\n",
    "\n",
    "def read_alphabets(alphabet_directory, directory):\n",
    "    \"\"\"\n",
    "    Reads all the characters from alphabet_directory and augment each image with 90, 180, 270 degrees of rotation.\n",
    "    \"\"\"\n",
    "    datax = None\n",
    "    datay = []\n",
    "    characters = os.listdir(alphabet_directory)\n",
    "    for character in characters:\n",
    "        images = os.listdir(alphabet_directory + character + '/')\n",
    "        for img in images:\n",
    "            image = cv2.resize(cv2.imread(alphabet_directory + character + '/' + img), (28,28))\n",
    "            image90 = image_rotate(image, 90)\n",
    "            image180 = image_rotate(image, 180)\n",
    "            image270 = image_rotate(image, 270)\n",
    "            image = np.expand_dims(image, 0)\n",
    "            if datax is None:\n",
    "                datax = np.vstack([image, image90, image180, image270])\n",
    "            else:\n",
    "                datax = np.vstack([datax, image, image90, image180, image270])\n",
    "            datay.append(directory + '_' + character + '_0')\n",
    "            datay.append(directory + '_' + character + '_90')\n",
    "            datay.append(directory + '_' + character + '_180')\n",
    "            datay.append(directory + '_' + character + '_270')\n",
    "    return datax, np.array(datay)\n",
    "\n",
    "def read_images(base_directory):\n",
    "    \"\"\"\n",
    "    Used multithreading for data reading to decrease the reading time drastically\n",
    "    \"\"\"\n",
    "    datax = None\n",
    "    datay = []\n",
    "    pool = mp.Pool(mp.cpu_count())\n",
    "    results = [pool.apply(read_alphabets, args=(base_directory + '/' + directory + '/', directory, )) for directory in os.listdir(base_directory)]\n",
    "    pool.close()\n",
    "    for result in results:\n",
    "        if datax is None:\n",
    "            datax = result[0]\n",
    "            datay = result[1]\n",
    "        else:\n",
    "            datax = np.vstack([datax, result[0]])\n",
    "            datay = np.concatenate([datay, result[1]])\n",
    "    return datax, datay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 680 ms, sys: 1.64 s, total: 2.32 s\n",
      "Wall time: 1min 21s\n"
     ]
    }
   ],
   "source": [
    "%time trainx, trainy = read_images('../input/omniglot/images_background/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 336 ms, sys: 844 ms, total: 1.18 s\n",
      "Wall time: 53.7 s\n"
     ]
    }
   ],
   "source": [
    "%time testx, testy = read_images('../input/omniglot/images_evaluation/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((77120, 28, 28, 3), (77120,), (52720, 28, 28, 3), (52720,))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainx.shape, trainy.shape, testx.shape, testy.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from tqdm import trange\n",
    "from time import sleep\n",
    "from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n",
    "use_gpu = torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([77120, 28, 28, 3]), torch.Size([52720, 28, 28, 3]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainx = torch.from_numpy(trainx).float()\n",
    "#trainy = torch.from_numpy(trainy)\n",
    "testx = torch.from_numpy(testx).float()\n",
    "#testy = torch.from_numpy(testy)\n",
    "if use_gpu:\n",
    "    trainx = trainx.cuda()\n",
    "    testx = testx.cuda()\n",
    "trainx.size(), testx.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainx = trainx.permute(0,3,1,2)\n",
    "testx = testx.permute(0,3,1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"\n",
    "    Image2Vector CNN which takes image of dimension (28x28x3) and return column vector length 64\n",
    "    \"\"\"\n",
    "    def sub_block(self, in_channels, out_channels=64, kernel_size=3):\n",
    "        block = torch.nn.Sequential(\n",
    "                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=1),\n",
    "                    torch.nn.BatchNorm2d(out_channels),\n",
    "                    torch.nn.ReLU(),\n",
    "                    torch.nn.MaxPool2d(kernel_size=2)\n",
    "                )\n",
    "        return block\n",
    "    \n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.convnet1 = self.sub_block(3)\n",
    "        self.convnet2 = self.sub_block(64)\n",
    "        self.convnet3 = self.sub_block(64)\n",
    "        self.convnet4 = self.sub_block(64)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.convnet1(x)\n",
    "        x = self.convnet2(x)\n",
    "        x = self.convnet3(x)\n",
    "        x = self.convnet4(x)\n",
    "        x = torch.flatten(x, start_dim=1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrototypicalNet(nn.Module):\n",
    "    def __init__(self, use_gpu=False):\n",
    "        super(PrototypicalNet, self).__init__()\n",
    "        self.f = Net()\n",
    "        self.gpu = use_gpu\n",
    "        if self.gpu:\n",
    "            self.f = self.f.cuda()\n",
    "    \n",
    "    def forward(self, datax, datay, Ns,Nc, Nq, total_classes):\n",
    "        \"\"\"\n",
    "        Implementation of one episode in Prototypical Net\n",
    "        datax: Training images\n",
    "        datay: Corresponding labels of datax\n",
    "        Nc: Number  of classes per episode\n",
    "        Ns: Number of support data per class\n",
    "        Nq:  Number of query data per class\n",
    "        total_classes: Total classes in training set\n",
    "        \"\"\"\n",
    "        k = total_classes.shape[0]\n",
    "        K = np.random.choice(total_classes, Nc, replace=False)\n",
    "        Query_x = torch.Tensor()\n",
    "        if(self.gpu):\n",
    "            Query_x = Query_x.cuda()\n",
    "        Query_y = []\n",
    "        Query_y_count = []\n",
    "        centroid_per_class  = {}\n",
    "        class_label = {}\n",
    "        label_encoding = 0\n",
    "        for cls in K:\n",
    "            S_cls, Q_cls = self.random_sample_cls(datax, datay, Ns, Nq, cls)\n",
    "            centroid_per_class[cls] = self.get_centroid(S_cls, Nc)\n",
    "            class_label[cls] = label_encoding\n",
    "            label_encoding += 1\n",
    "            Query_x = torch.cat((Query_x, Q_cls), 0) # Joining all the query set together\n",
    "            Query_y += [cls]\n",
    "            Query_y_count += [Q_cls.shape[0]]\n",
    "        Query_y, Query_y_labels = self.get_query_y(Query_y, Query_y_count, class_label)\n",
    "        Query_x = self.get_query_x(Query_x, centroid_per_class, Query_y_labels)\n",
    "        return Query_x, Query_y\n",
    "    \n",
    "    def random_sample_cls(self, datax, datay, Ns, Nq, cls):\n",
    "        \"\"\"\n",
    "        Randomly samples Ns examples as support set and Nq as Query set\n",
    "        \"\"\"\n",
    "        data = datax[(datay == cls).nonzero()]\n",
    "        perm = torch.randperm(data.shape[0])\n",
    "        idx = perm[:Ns]\n",
    "        S_cls = data[idx]\n",
    "        idx = perm[Ns : Ns+Nq]\n",
    "        Q_cls = data[idx]\n",
    "        if self.gpu:\n",
    "            S_cls = S_cls.cuda()\n",
    "            Q_cls = Q_cls.cuda()\n",
    "        return S_cls, Q_cls\n",
    "    \n",
    "    def get_centroid(self, S_cls, Nc):\n",
    "        \"\"\"\n",
    "        Returns a centroid vector of support set for a class\n",
    "        \"\"\"\n",
    "        return torch.sum(self.f(S_cls), 0).unsqueeze(1).transpose(0,1) / Nc\n",
    "    \n",
    "    def get_query_y(self, Qy, Qyc, class_label):\n",
    "        \"\"\"\n",
    "        Returns labeled representation of classes of Query set and a list of labels.\n",
    "        \"\"\"\n",
    "        labels = []\n",
    "        m = len(Qy)\n",
    "        for i in range(m):\n",
    "            labels += [Qy[i]] * Qyc[i]\n",
    "        labels = np.array(labels).reshape(len(labels), 1)\n",
    "        label_encoder = LabelEncoder()\n",
    "        Query_y = torch.Tensor(label_encoder.fit_transform(labels).astype(int)).long()\n",
    "        if self.gpu:\n",
    "            Query_y = Query_y.cuda()\n",
    "        Query_y_labels = np.unique(labels)\n",
    "        return Query_y, Query_y_labels\n",
    "    \n",
    "    def get_centroid_matrix(self, centroid_per_class, Query_y_labels):\n",
    "        \"\"\"\n",
    "        Returns the centroid matrix where each column is a centroid of a class.\n",
    "        \"\"\"\n",
    "        centroid_matrix = torch.Tensor()\n",
    "        if(self.gpu):\n",
    "            centroid_matrix = centroid_matrix.cuda()\n",
    "        for label in Query_y_labels:\n",
    "            centroid_matrix = torch.cat((centroid_matrix, centroid_per_class[label]))\n",
    "        if self.gpu:\n",
    "            centroid_matrix = centroid_matrix.cuda()\n",
    "        return centroid_matrix\n",
    "    \n",
    "    def get_query_x(self, Query_x, centroid_per_class, Query_y_labels):\n",
    "        \"\"\"\n",
    "        Returns distance matrix from each Query image to each centroid.\n",
    "        \"\"\"\n",
    "        centroid_matrix = self.get_centroid_matrix(centroid_per_class, Query_y_labels)\n",
    "        Query_x = self.f(Query_x)\n",
    "        m = Query_x.size(0)\n",
    "        n = centroid_matrix.size(0)\n",
    "        # The below expressions expand both the matrices such that they become compatible to each other in order to caclulate L2 distance.\n",
    "        centroid_matrix = centroid_matrix.expand(m, centroid_matrix.size(0), centroid_matrix.size(1)) # Expanding centroid matrix to \"m\".\n",
    "        Query_matrix = Query_x.expand(n, Query_x.size(0), Query_x.size(1)).transpose(0,1) # Expanding Query matrix \"n\" times\n",
    "        Qx = torch.pairwise_distance(centroid_matrix.transpose(1,2), Query_matrix.transpose(1,2))\n",
    "        return Qx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "protonet = PrototypicalNet(use_gpu=use_gpu)\n",
    "optimizer = optim.SGD(protonet.parameters(), lr = 0.01, momentum=0.99)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(datax, datay, Ns,Nc, Nq):\n",
    "    optimizer.zero_grad()\n",
    "    Qx, Qy= protonet(datax, datay, Ns, Nc, Nq, np.unique(datay))\n",
    "    pred = torch.log_softmax(Qx, dim=-1)\n",
    "    loss = F.nll_loss(pred, Qy)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    acc = torch.mean((torch.argmax(pred, 1) == Qy).float())\n",
    "    return loss, acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_episode = 16000\n",
    "frame_size = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/sklearn/preprocessing/label.py:235: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
      "  y = column_or_1d(y, warn=True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frame Number: 1 Frame Loss:  3.871866455078125 Frame Accuracy: 3.6483432769775392\n",
      "Frame Number: 2 Frame Loss:  3.152219970703125 Frame Accuracy: 10.43133316040039\n",
      "Frame Number: 3 Frame Loss:  2.39058642578125 Frame Accuracy: 24.33332977294922\n",
      "Frame Number: 4 Frame Loss:  1.94584033203125 Frame Accuracy: 39.45198974609375\n",
      "Frame Number: 5 Frame Loss:  1.5654544677734374 Frame Accuracy: 55.72529296875\n",
      "Frame Number: 6 Frame Loss:  1.242635009765625 Frame Accuracy: 68.28630981445312\n",
      "Frame Number: 7 Frame Loss:  0.9735999145507812 Frame Accuracy: 78.3112548828125\n",
      "Frame Number: 8 Frame Loss:  0.770054931640625 Frame Accuracy: 83.34801025390625\n",
      "Frame Number: 9 Frame Loss:  0.6496942138671875 Frame Accuracy: 86.1826904296875\n",
      "Frame Number: 10 Frame Loss:  0.5608416137695312 Frame Accuracy: 88.241015625\n",
      "Frame Number: 11 Frame Loss:  0.4988248291015625 Frame Accuracy: 89.77072143554688\n",
      "Frame Number: 12 Frame Loss:  0.4493779296875 Frame Accuracy: 90.85570678710937\n",
      "Frame Number: 13 Frame Loss:  0.40475567626953124 Frame Accuracy: 91.9748046875\n",
      "Frame Number: 14 Frame Loss:  0.37008343505859376 Frame Accuracy: 92.7223876953125\n",
      "Frame Number: 15 Frame Loss:  0.3400647888183594 Frame Accuracy: 93.38535766601562\n",
      "Frame Number: 16 Frame Loss:  0.318818359375 Frame Accuracy: 93.79469604492188\n"
     ]
    }
   ],
   "source": [
    "frame_loss = 0\n",
    "frame_acc = 0\n",
    "for i in range(num_episode):\n",
    "    loss, acc = train_step(trainx, trainy, 5, 60, 5)\n",
    "    frame_loss += loss.data\n",
    "    frame_acc += acc.data\n",
    "    if( (i+1) % frame_size == 0):\n",
    "        print(\"Frame Number:\", ((i+1) // frame_size), 'Frame Loss: ', frame_loss.data.cpu().numpy().tolist()/ frame_size, 'Frame Accuracy:', (frame_acc.data.cpu().numpy().tolist() * 100) / frame_size)\n",
    "        frame_loss = 0\n",
    "        frame_acc = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_step(datax, datay, Ns,Nc, Nq):\n",
    "    Qx, Qy= protonet(datax, datay, Ns, Nc, Nq, np.unique(datay))\n",
    "    pred = torch.log_softmax(Qx, dim=-1)\n",
    "    loss = F.nll_loss(pred, Qy)\n",
    "    acc = torch.mean((torch.argmax(pred, 1) == Qy).float())\n",
    "    return loss, acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_test_episode = 2000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/sklearn/preprocessing/label.py:235: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
      "  y = column_or_1d(y, warn=True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg Loss:  0.5742660522460937 Avg Accuracy: 84.63178100585938\n"
     ]
    }
   ],
   "source": [
    "avg_loss = 0\n",
    "avg_acc = 0\n",
    "for _ in range(num_test_episode):\n",
    "    loss, acc = test_step(testx, testy, 5, 60, 15)\n",
    "    avg_loss += loss.data\n",
    "    avg_acc += acc.data\n",
    "print('Avg Loss: ', avg_loss.data.cpu().numpy().tolist() / num_test_episode , 'Avg Accuracy:', (avg_acc.data.cpu().numpy().tolist() * 100) / num_test_episode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
